"""
    © 2021 This work is licensed under a CC-BY-NC-SA license.
    Title: *"Behavioral cloning in recurrent spiking networks: A comprehensive framework"*
    Authors: Anonymous
"""

import json
import os
import pickle
from os import listdir

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.markers import MarkerStyle

rc('font', size = 14)

def plt_trj(agent, T=80):
    """Turn a 2-D trajectory into a LineCollection with graded colour and width.

    ``agent`` is reshaped to ``(-1, 1, 2)``, i.e. read as a sequence of (x, y)
    points, and consecutive points are joined into line segments.  Segment
    colours come from the ``coolwarm`` colormap, driven by
    ``np.linspace(0, 60, T)`` under ``Normalize(0, T)``.  Line widths ramp
    linearly from 0.5 to 2.5 over ``T`` entries.

    NOTE(review): the colour values span [0, 60] while the norm spans [0, T].
    Both the colour and width arrays have ``T`` entries, whereas there are
    ``len(points) - 1`` segments.  Confirm both are intended.  This helper is
    not called in the visible part of this file.

    Returns:
        The configured ``LineCollection``; it is not added to any axes.
    """
    points = agent.reshape(-1, 1, 2)
    # Each segment is a (start, end) pair of consecutive points.
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    widths = np.linspace(0.5, 2.5, num=T)
    colour_values = np.linspace(0, 60, num=T)

    collection = LineCollection(
        segments,
        cmap='coolwarm',
        norm=plt.Normalize(0, T),
        lw=widths,
    )
    collection.set_array(colour_values)
    return collection

def BlockB(fig, ax, traj, r, max_r = 20, rt = .7):
    """Draw panel B: trajectories towards the training targets, coloured by reward.

    Reads the module-level globals ``trainset``, ``testset``, ``train_theta``
    and ``test_theta``, which are set up by the script at the bottom of this file.

    Args:
        fig: Figure that receives the reward colorbar.
        ax: Axes to draw on.
        traj: Trajectories indexed with a boolean mask over ``testset``.  Each
            selected trajectory is plotted via ``trj.T``, so presumably it has
            shape (n_steps, 2) (TODO confirm).
        r: Final rewards, one per entry of ``testset``.
        max_r: Reward mapped to the "good" end of the colormap.
        rt: Radius of the circle on which the test targets are drawn.

    Returns:
        The ``(fig, ax)`` pair.
    """
    # ``plt.get_cmap`` works on all Matplotlib versions;
    # ``matplotlib.cm.get_cmap`` was removed in 3.9.
    cmap = plt.get_cmap('RdYlGn')

    # Rescale rewards so the worst observed reward maps to 0 and ``max_r`` to 1.
    r = np.asarray(r, dtype=float)
    r_min = np.min(r)
    span = max_r - r_min
    # Guard the degenerate case ``min(r) == max_r``, which would divide by zero.
    cols = (r - r_min) / span if span != 0 else np.ones_like(r)

    # Keep only the test runs whose (floored) angle coincides with a training angle.
    floored_test = np.floor(testset)
    TR = np.array([traj[floored_test == t] for t in trainset]).squeeze()
    cols = np.array([cols[floored_test == t] for t in trainset]).squeeze()

    for trj, c in zip(TR, cols):
        ax.plot(*trj.T, c=cmap(c), lw=2)

    # Every other test target, drawn as a star on the circle of radius ``rt``.
    ax.scatter(rt * np.cos(test_theta[::2]), rt * np.sin(test_theta[::2]), marker='*', color='indigo')

    # Training targets: the right-pointing arrowhead glyph, rotated by theta + pi so it
    # points towards the origin, placed just outside the circle.
    _rt = rt + 0.07
    for theta in train_theta:
        t = MarkerStyle(marker='$\u27A4$')
        # NOTE(review): relies on the private ``_transform`` attribute of MarkerStyle.
        t._transform = t.get_transform().rotate(theta + np.pi)
        ax.scatter(_rt * np.cos(theta), _rt * np.sin(theta), marker=t, color='m', s=60)

    # Colorbar over the normalised reward range [0, 1].
    norm = Normalize(vmin=0, vmax=1)
    cb = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=[0, 1])
    cb.ax.set_yticklabels(['bad', 'good'])
    cb.ax.set_ylabel('final reward', labelpad=-20)

    ax.set_xticks([-.6, .6])
    ax.set_yticks([-.2, .8])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(-.6, .6)
    ax.spines['left'].set_bounds(-.2, .8)

    return fig, ax

def BlockC(fig, ax):
    """Draw panel C: mean reward versus test angle, one curve per rank.

    Reads the module-level globals ``ranks``, ``data``, ``testset`` and
    ``trainset``.  Each ``data[rank]`` is averaged over axis 0 and plotted
    against ``testset``, so presumably it has shape (n_runs, len(testset))
    (TODO confirm).  Training angles are marked with downward arrows at
    reward 30.

    Returns:
        The ``(fig, ax)`` pair.
    """
    # ``plt.get_cmap`` works on all Matplotlib versions;
    # ``matplotlib.cm.get_cmap`` was removed in 3.9.
    cmap = plt.get_cmap('copper')
    n_ranks = len(ranks)
    # One evenly spaced colour per rank, by position in ``ranks`` (not by rank value).
    cols = [cmap(c) for c in np.linspace(0., 1., num=n_ranks)]

    for rank, c in zip(ranks, cols):
        ax.plot(testset, np.mean(data[rank], axis=0), color=c)

    # The right-pointing arrowhead glyph rotated by -90 degrees, i.e. pointing down.
    # It is identical for every training angle, so build it once.
    arrow = MarkerStyle(marker='$\u27A4$')
    # NOTE(review): relies on the private ``_transform`` attribute of MarkerStyle.
    arrow._transform = arrow.get_transform().rotate(-0.5 * np.pi)
    for theta in trainset:
        ax.scatter(theta, 30, marker=arrow, color='m', s=60)

    # The colorbar is built in index space (0 .. n_ranks-1), with ticks labelled by
    # rank.  Each tick then shows exactly the colour of its curve, even when the
    # ranks are unevenly spaced; normalising by rank value would not match ``cols``.
    norm = Normalize(vmin=0, vmax=n_ranks - 1)
    cb = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=np.arange(n_ranks))
    cb.ax.set_yticklabels(ranks)
    cb.ax.set_title('ranks', fontsize=14)

    ax.set_xticks([30, 50, 70, 90, 110, 130, 150])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(30, 150)
    ax.spines['left'].set_bounds(0, 30)

    ax.set_xlabel('$\\theta$')
    ax.set_ylabel('reward')

    return fig, ax


def BlockD(fig, ax):
    """Draw panel D: overall mean reward with error bars as a function of rank.

    Reads the module-level globals ``ranks`` and ``data``.  Consecutive ranks
    are joined by line segments, and each rank gets its own marker and error
    bar in the same per-rank colour as panel C's curves.

    Returns:
        The ``(fig, ax)`` pair.
    """
    # ``plt.get_cmap`` works on all Matplotlib versions;
    # ``matplotlib.cm.get_cmap`` was removed in 3.9.
    cmap = plt.get_cmap('copper')
    cols = [cmap(c) for c in np.linspace(0., 1., num=len(ranks))]

    # Grand mean over every entry of data[rank].
    means = [np.mean(data[rank]) for rank in ranks]
    # NOTE(review): np.std runs over all entries of data[rank] but is divided by
    # sqrt(number of rows); confirm this is the intended standard error.
    sems = [np.std(data[rank]) / np.sqrt(data[rank].shape[0]) for rank in ranks]

    # Connecting segments; each one takes the colour of its left endpoint.
    for rank1, rank2, R1, R2, c in zip(ranks, ranks[1:], means, means[1:], cols):
        ax.plot([rank1, rank2], [R1, R2], ls='-', color=c)

    # One marker and error bar per rank, in that rank's own colour, drawn after
    # the lines so they sit on top.  (Drawing pairs in the left colour only gave
    # the last rank the second-to-last colour, and skipped a lone rank.)
    for rank, R, Re, c in zip(ranks, means, sems, cols):
        ax.errorbar(rank, R, Re, fmt='o', mec='w', ms=8, mfc=c, capsize=5, color=c)

    ax.set_xlabel('ranks')
    ax.set_ylabel('$\\langle \\mathrm{reward} \\rangle$')

    ax.set_xticks([0, 50, 100, 150, 200, 250, 300])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_bounds(0, 10)
    ax.spines['bottom'].set_bounds(ranks[0], ranks[-1])

    return fig, ax

# =============== HERE WE IMPORT THE DATA =======================
# These are the rewards and trajectory files generated by the
# button_food.py script.
# NOTE: To see your own results, change these two variables to the
#       paths of the newly generated files.
rewards_filepath = 'data/precomputed_rewards.pkl'
traject_filepath = 'data/precomputed_best_traj.pkl'

# NOTE: pickle.load can execute arbitrary code; only load files you generated
# yourself (e.g. with button_food.py) or otherwise trust.
with open(rewards_filepath, 'rb') as f:
    par, data = pickle.load(f)
with open(traject_filepath, 'rb') as f:
    (rank, idx), traj = pickle.load(f)

# Module-level globals read by BlockB/BlockC/BlockD.
ranks = np.sort(list(data.keys()))

trainset    = np.array(par['trainset'])
validset    = np.array(par['validset'])
testset     = np.linspace(*par['testset'])

# Angles converted from degrees to radians.
train_theta = trainset * np.pi / 180.
valid_theta = validset * np.pi / 180.
test_theta  = testset  * np.pi / 180.

# ======== HERE WE COMPOSE THE FIGURE ===============
fig, axes = plt.subplots(ncols=3, figsize=(13, 4))

fig, axes[0] = BlockB(fig, axes[0], traj, data[rank][idx])
fig, axes[1] = BlockC(fig, axes[1])
fig, axes[2] = BlockD(fig, axes[2])

fig.tight_layout()

savename = 'fig/Figure_4'

# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs(os.path.dirname(savename) or '.', exist_ok=True)

fig.savefig(savename + '.png', dpi=300)
fig.savefig(savename + '.eps', dpi=300)
plt.show()